%load_ext autoreload
%autoreload 2
from scip_workflows.common import *
import pickle
import anndata
import scanpy
import shap
from matplotlib.gridspec import GridSpec
from matplotlib.patches import ConnectionStyle
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import mutual_info_classif
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split
from scip_workflows.core import plot_gate_czi
# Notebook-wide setup: initialise SHAP's JavaScript plotting support and use a
# 200 dpi default resolution for every matplotlib figure created below.
shap.initjs()
plt.rc("figure", dpi=200)
# Resolve input and output locations. When run by Snakemake, a `snakemake`
# object is available in the script namespace. When run interactively that
# name does not exist (NameError) and hard-coded development paths are used.
try:
    adata = snakemake.input.adata
    output_three = snakemake.output[0]
    output_cd15_cd45 = snakemake.output[1]
    output_cd15_siglec8 = snakemake.output[2]
    output_unclassified = snakemake.output[3]
    # Snakemake passes plain strings, but later code uses pathlib operations
    # (`image_root.joinpath(...)`), so convert explicitly.
    image_root = Path(snakemake.input.image_root)
    # `data_dir` is needed later to locate the `masks` directory and was
    # previously undefined on this branch. Mirror the interactive layout,
    # where the pickle lives directly in `data_dir`.
    # NOTE(review): assumes the Snakefile uses the same layout - confirm.
    data_dir = Path(adata).parent
except NameError:
    image_root = Path("/home/maximl/scratch/data/vsc/datasets/cd7/800")
    data_dir = Path("/home/maximl/scratch/data/vsc/datasets/cd7/800/scip/061020221736/")
    adata = data_dir / "adata.pickle"
    output_three = data_dir / "figures" / "cluster_panels.png"
    output_cd15_cd45 = data_dir / "figures" / "cd15_vs_cd45_facets.png"
    output_cd15_siglec8 = data_dir / "figures" / "cd15_vs_siglec8_facets.png"
    output_unclassified = data_dir / "figures" / "unclassified_cluster.png"
# Display names of the summed-intensity features of the fluorescence channels.
_MARKER_DISPLAY_NAMES = {
    "feat_combined_sum_DAPI": "DAPI",
    "feat_combined_sum_EGFP": "CD45",
    "feat_combined_sum_RPe": "Siglec 8",
    "feat_combined_sum_APC": "CD15",
}


def map_names(a):
    """Return the human-readable marker name for feature name ``a``.

    Feature names without a known display name are returned unchanged. The
    marker columns are selected by prefix, so unexpected names can reach this
    function, and relabelling tick labels should not crash on them.
    """
    return _MARKER_DISPLAY_NAMES.get(a, a)
# Load the AnnData object. This rebinds `adata` from the pickle path to the
# loaded object. pickle.load can execute arbitrary code, so only use it on
# trusted files.
with open(adata, "rb") as fh:
    adata = pickle.load(fh)


def _relocate_under_image_root(p):
    """Re-root a stored image path under `image_root`.

    Keeps only the path components after the first "800" component.
    """
    parts = Path(p).parts
    return image_root.joinpath(*parts[parts.index("800") + 1 :])


adata.obs.meta_path = adata.obs.meta_path.apply(_relocate_under_image_root)
# Summed-intensity features of the EGFP, RPe, APC and DAPI channels, kept in
# the order in which they appear in `adata.var_names`.
_marker_prefixes = tuple(
    f"feat_combined_sum_{channel}" for channel in ("EGFP", "RPe", "APC", "DAPI")
)
markers = [name for name in adata.var_names if name.startswith(_marker_prefixes)]
# Overview of the raw Leiden clustering. Panels: marker matrixplot with
# dendrogram, UMAP coloured by cluster, and cluster sizes per replicate.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
matrix_ax, umap_ax, count_ax = axes
ax = scanpy.pl.matrixplot(
    adata,
    markers,
    groupby="leiden",
    dendrogram=True,
    vmin=-2,
    vmax=2,
    cmap="RdBu_r",
    ax=matrix_ax,
    show=False,
    use_raw=False,
)
# Show marker names instead of raw feature names on the x axis.
main_ax = ax["mainplot_ax"]
main_ax.set_xticklabels(
    [map_names(label.get_text()) for label in main_ax.get_xticklabels()]
)
scanpy.pl.umap(adata, color="leiden", legend_loc="on data", ax=umap_ax, show=False)
seaborn.countplot(data=adata.obs, x="leiden", hue="meta_replicate", ax=count_ax)
WARNING: dendrogram data not found (using key=dendrogram_leiden). Running `sc.tl.dendrogram` with default parameters. For fine tuning it is recommended to run `sc.tl.dendrogram` independently.
<AxesSubplot:xlabel='leiden', ylabel='count'>
# Merged clustering: Leiden clusters "2", "4", "6" and "8" keep their label,
# and every other cluster is folded into "1".
_kept_clusters = {"2", "4", "6", "8"}
adata.obs["leiden_merged"] = adata.obs.leiden.map(
    lambda cluster: cluster if cluster in _kept_clusters else "1"
)
# Overview of the merged clustering (`leiden_merged`). Panel order: cluster
# sizes per replicate, marker matrixplot with dendrogram, UMAP.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
count_ax, matrix_ax, umap_ax = axes
ax = scanpy.pl.matrixplot(
    adata,
    markers,
    groupby="leiden_merged",
    dendrogram=True,
    vmin=-2,
    vmax=2,
    cmap="RdBu_r",
    ax=matrix_ax,
    show=False,
    use_raw=False,
)
# Show marker names instead of raw feature names on the x axis.
main_ax = ax["mainplot_ax"]
main_ax.set_xticklabels(
    [map_names(label.get_text()) for label in main_ax.get_xticklabels()]
)
scanpy.pl.umap(adata, color="leiden_merged", ax=umap_ax, show=False)
seaborn.countplot(data=adata.obs, x="leiden_merged", hue="meta_replicate", ax=count_ax)
WARNING: dendrogram data not found (using key=dendrogram_leiden_merged). Running `sc.tl.dendrogram` with default parameters. For fine tuning it is recommended to run `sc.tl.dendrogram` independently.
<AxesSubplot:xlabel='leiden_merged', ylabel='count'>
# CD45 (EGFP) vs CD15 (APC) summed intensities, coloured by merged cluster.
scanpy.pl.scatter(
    adata,
    x="feat_combined_sum_EGFP",
    y="feat_combined_sum_APC",
    color="leiden_merged",
    legend_loc="on data",
)

# Faceted version with one panel per merged cluster. Each panel first draws
# all cells in grey as context, then its own cluster on top.
cd45_cd15 = scanpy.get.obs_df(
    adata,
    keys=["feat_combined_sum_EGFP", "feat_combined_sum_APC", "leiden_merged"],
    use_raw=True,
)
grid = seaborn.FacetGrid(data=cd45_cd15, col="leiden_merged")
grid.set_titles(col_template="Cluster {col_name}")
for ax in grid.axes.ravel():
    # The grey background is identical in every panel. Reuse the frame
    # extracted once above instead of re-extracting it per panel.
    seaborn.scatterplot(
        data=cd45_cd15,
        x="feat_combined_sum_EGFP",
        y="feat_combined_sum_APC",
        color="grey",
        s=0.5,
        alpha=0.5,
        ax=ax,
    )
grid.map_dataframe(
    seaborn.scatterplot, x="feat_combined_sum_EGFP", y="feat_combined_sum_APC", s=1.5
)
for ax in grid.axes.ravel():
    # No tick marks; only the marker names label the axes.
    ax.set_yticks([])
    ax.set_xticks([])
    ax.set_xlabel("CD45")
    ax.set_ylabel("CD15")
plt.savefig(output_cd15_cd45, bbox_inches="tight", pad_inches=0, dpi=200)
# Siglec 8 (RPe) vs CD15 (APC) summed intensities, restricted to Leiden
# clusters "1", "6", "8" and "9". The cluster mask is computed once and reused.
in_selected_clusters = adata.obs.leiden.isin(["1", "6", "8", "9"])
scanpy.pl.scatter(
    adata[in_selected_clusters],
    x="feat_combined_sum_RPe",
    y="feat_combined_sum_APC",
    color="leiden",
    legend_loc="on data",
)

# Faceted by merged cluster. Each panel shows all selected cells in grey as
# context, then its own cluster on top.
siglec8_cd15 = scanpy.get.obs_df(
    adata[in_selected_clusters],
    keys=["feat_combined_sum_RPe", "feat_combined_sum_APC", "leiden_merged"],
    use_raw=True,
)
grid = seaborn.FacetGrid(data=siglec8_cd15, col="leiden_merged")
grid.set_titles(col_template="Cluster {col_name}")
for ax in grid.axes.ravel():
    # Same background in every panel: reuse the frame extracted above
    # instead of subsetting and extracting again per panel.
    seaborn.scatterplot(
        data=siglec8_cd15,
        x="feat_combined_sum_RPe",
        y="feat_combined_sum_APC",
        color="grey",
        s=0.5,
        alpha=0.5,
        ax=ax,
    )
grid.map_dataframe(
    seaborn.scatterplot, x="feat_combined_sum_RPe", y="feat_combined_sum_APC", s=1.5
)
for ax in grid.axes.ravel():
    # No tick marks; only the marker names label the axes.
    ax.set_yticks([])
    ax.set_xticks([])
    ax.set_xlabel("Siglec 8")
    ax.set_ylabel("CD15")
plt.savefig(output_cd15_siglec8, bbox_inches="tight", pad_inches=0, dpi=200)
# Train a random forest that predicts the merged cluster label from the
# features flagged in `adata.var.selected_corr`, then explain it with SHAP.
features = adata[:, adata.var.selected_corr]
labels = adata.obs["leiden_merged"]
X_train, X_test, y_train, y_test = train_test_split(
    features,
    labels,
    test_size=0.1,
    stratify=labels,
    # The forest below is seeded; seed the split as well so the reported
    # accuracy and SHAP values are reproducible between runs.
    random_state=0,
)
model = RandomForestClassifier(n_estimators=50, random_state=0).fit(
    X_train.to_df(), y_train.values
)
preds = model.predict(X_test.to_df())
# Bare expressions are only displayed in a notebook; print so they are also
# visible when run as a script.
print("Balanced accuracy:", balanced_accuracy_score(y_test.values, preds))
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test.to_df())
print(y_train.cat.categories)
# SHAP's class axis follows `model.classes_`, so look the class up by label
# rather than by hard-coded position. The original used index 3, which is
# cluster "6" when all five merged clusters are present.
shap.plots.beeswarm(shap_values[..., list(model.classes_).index("6")])
# Per-cell path (as a string) to the mask file of the cell's scene and tile:
# <data_dir>/masks/<meta_scene>_<meta_tile>.npy
masks_dir = data_dir / "masks"
adata.obs["meta_masks"] = adata.obs[["meta_scene", "meta_tile"]].apply(
    lambda row: str(masks_dir / "%s_%s.npy") % (row.meta_scene, row.meta_tile),
    axis=1,
)
# Two image galleries for Leiden cluster "6": first with the mask files from
# `meta_masks`, then without masks.
for mask_kwargs in ({"masks_path_col": "meta_masks"}, {}):
    plot_gate_czi(
        sel=adata.obs["leiden"] == "6",
        df=adata.obs,
        channels=[0, 1, 2, 3, 4, 5, 6],
        maxn=50,
        **mask_kwargs,
    )
# Saves the current matplotlib figure, presumably the mask-free gallery.
# NOTE(review): a later cell writes `output_unclassified` again, overwriting
# this file.
plt.savefig(output_unclassified)
# 5th and 95th percentile of every summed-intensity feature.
quantiles = adata.to_df().filter(regex="feat_combined_sum").quantile([0.05, 0.95])
# One (5%, 95%) pair per channel feature, presumably matching channel indices
# 0..6 passed below. TODO confirm the channel order in plot_gate_czi.
channel_features = [
    "feat_combined_sum_" + channel
    for channel in ("DAPI", "EGFP", "RPe", "APC", "Bright", "Oblique", "PGC")
]
extent = quantiles.loc[:, channel_features].T.values
# Gallery of Leiden cluster "6" using these per-channel extents.
plot_gate_czi(
    sel=adata.obs["leiden"] == "6",
    df=adata.obs,
    channels=[0, 1, 2, 3, 4, 5, 6],
    maxn=50,
    extent=extent,
)
# CD15 (APC) intensity per merged cluster and its SHAP contributions.
scanpy.pl.violin(adata, "feat_combined_sum_APC", groupby="leiden_merged")
# SHAP's class axis follows `model.classes_`, so resolve classes by label.
shap_classes = list(model.classes_)
# Original index 4, i.e. cluster "8" when all five merged clusters are present.
shap.plots.scatter(
    shap_values[..., "feat_combined_sum_APC", shap_classes.index("8")]
)
# The original plotted class index 5. `leiden_merged` only takes the labels
# "1", "2", "4", "6" and "8" (indices 0-4), so that call always raised
# IndexError and aborted the script before the remaining figures were saved.
# NOTE(review): presumably this targeted Leiden cluster "9" (inspected next),
# from when it was kept as its own class. TODO confirm. The plot is drawn only
# if such a class exists.
if "9" in shap_classes:
    shap.plots.beeswarm(shap_values[..., shap_classes.index("9")])
# Image gallery of Leiden cluster "9" with the mask files from `meta_masks`.
in_cluster9 = adata.obs["leiden"] == "9"
plot_gate_czi(
    sel=in_cluster9,
    df=adata.obs,
    channels=list(range(7)),
    maxn=30,
    masks_path_col="meta_masks",
)
# Annotation label for each merged Leiden cluster.
cluster2annotation = {
    "1": "granulocytes",
    "8": "eosinophils",
    "4": "monocytes",
    "2": "lymphocytes",
    "6": "unclassified",
}
# Ordered categorical dtype that fixes the order of the cell-type categories.
cat_type = pandas.CategoricalDtype(
    categories=["monocytes", "lymphocytes", "granulocytes", "eosinophils", "unclassified"],
    ordered=True,
)
# New `.obs` column "cell type": each cell's merged cluster translated through
# the mapping above, then cast to the ordered dtype.
cell_type_labels = adata.obs["leiden_merged"].map(cluster2annotation)
adata.obs["cell type"] = cell_type_labels.astype(cat_type)
# Final three-panel figure by cell type. Panels: counts per replicate, marker
# intensity matrix, UMAP. Saved to `output_three`.
fig, axes = plt.subplots(1, 3, figsize=(15, 5), tight_layout=True)
count_ax, matrix_ax, umap_ax = axes
ax = scanpy.pl.matrixplot(
    adata,
    markers,
    groupby="cell type",
    dendrogram=False,
    vmin=-2,
    vmax=2,
    cmap="RdBu_r",
    ax=matrix_ax,
    show=False,
    use_raw=False,
)
# Show marker names instead of raw feature names on the x axis.
main_ax = ax["mainplot_ax"]
main_ax.set_xticklabels(
    [map_names(label.get_text()) for label in main_ax.get_xticklabels()]
)
scanpy.pl.umap(adata, color="cell type", ax=umap_ax, show=False, palette="tab10")
seaborn.countplot(data=adata.obs, y="cell type", hue="meta_replicate", ax=count_ax)
for panel, title in zip(axes, ("Cell type counts", "Marker intensity", "UMAP")):
    panel.set_title(title)
count_ax.legend(title="Replicate")
plt.savefig(output_three, bbox_inches="tight", pad_inches=0, dpi=200)
# Number and fraction of cells per cell type, printed as a LaTeX table.
# The column is named explicitly: since pandas 2.0, `value_counts()` returns a
# Series named "count" (not after the source column), so indexing the frame by
# "cell type" would raise KeyError.
counts = adata.obs["cell type"].value_counts().to_frame("Count")
counts["Fraction"] = counts["Count"] / counts["Count"].sum()
# pandas 2.x also names the index after the source column. Drop the name so
# the LaTeX header has no extra index-name row, as with pandas 1.x.
counts = counts.rename_axis(None)
print(counts.style.to_latex(hrules=True))
\begin{tabular}{lrr}
\toprule
{} & {Count} & {Fraction} \\
\midrule
granulocytes & 21725 & 0.730989 \\
lymphocytes & 4904 & 0.165007 \\
monocytes & 1737 & 0.058445 \\
unclassified & 1031 & 0.034690 \\
eosinophils & 323 & 0.010868 \\
\bottomrule
\end{tabular}
# Image gallery of the cells annotated "unclassified" (merged cluster "6"),
# written to `output_unclassified`.
# The whole-matrix percentile computation that used to precede this call was
# identical to the earlier `quantiles` and unused here, so it is not repeated.
plot_gate_czi(
    sel=adata.obs["cell type"] == "unclassified",
    df=adata.obs,
    channels=[0, 1, 2, 3, 4, 5, 6],
    maxn=40,
)
plt.savefig(output_unclassified, bbox_inches="tight")
0 P2-D2 0 P3-D1 0 P3-D5 0 P4-D3 0 P4-D5 0 P5-D4 0 P6-D4 0 P9-D2 0 P9-D5 0 P10-D1 0 P10-D3 0 P12-D1 0 P12-D4 0 P13-D1 0 P13-D3 0 P13-D5 0 P14-D1 0 P14-D5 0 P15-D1 0 P15-D4 0 P17-D2 0 P18-D2 0 P19-D3 0 P19-D4 0 P20-D3 0 P20-D5 0 P21-D3 0 P22-D1 0 P22-D4 0 P22-D5 0 P23-D3 0 P23-D4 0 P24-D1 0 P24-D3 0 P24-D4